import os
import os.path as osp
import glob
from typing import Tuple

import gym
import clip
import wandb
from tqdm import tqdm
import numpy as np

from diffgro.experiments.skill_diffuser import SkillDiffuser, SkillDiffuserPolicy
from diffgro.common.buffers import load_traj, MTTrajectoryBuffer
from diffgro.utils.config import load_config, save_config
from diffgro.utils import Parser, make_dir, print_r, print_y, print_b
from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.utils import make_dir, save_video, write_annotation


# CLIP ViT-B/16 model, passed as `lm` to get_skill_embed to embed the task
# descriptions below. It is loaded once at import time, so merely importing
# this module pays the model-load cost. The second return value of clip.load
# is discarded (presumably the image preprocessing transform -- unused here).
lm, _ = clip.load("ViT-B/16")
# Natural-language description of each task, keyed by task name. Keys are
# looked up with dataset sub-directory names (make_buff) and with the env-name
# part of args.env_name (evaluate); a missing key raises KeyError there. Each
# description is embedded via get_skill_embed and fed to the policy as the
# task/language conditioning.
task_detail = {
    "window-close-variant-v2": "push and close a window",
    "window-open-variant-v2": "push and open a window",
    "door-open-variant-v2": "open a door with a revolving joint",
    "peg-insert-side-variant-v2": "insert a peg sideways to the goal point",
    "drawer-open-variant-v2": "open a drawer",
    "pick-place-variant-v2": "pick a puck, and place the puck to the goal",
    "reach-variant-v2": "reach the goal point",
    "button-press-variant-v2": "press the button from the front",
    "push-variant-v2": "push the puck to the goal point",
    "drawer-close-variant-v2": "push and close a drawer",
}


def make_env(args):
    """Create the gym environment named by ``args.env_name``.

    ``args.env_name`` must have the form ``"<domain>.<env>"``; the ``<env>``
    part is passed to ``gym.make`` together with ``args.seed``. The env's
    ``goal_resistance`` variant is then replaced by a ``Categorical`` built
    from the single value ``args.goal_resistance``.

    Returns:
        ``(env, domain_name, env_name)``.
    """
    print_r(f"<< Making Environment for {args.env_name}... >>")
    domain, name = args.env_name.split(".")
    environment = gym.make(name, seed=args.seed)

    obs_shape = environment.observation_space.shape
    act_shape = environment.action_space.shape
    print_y(f"Obs Space: {obs_shape}, Act Space: {act_shape}")

    from diffgro.environments.variant import Categorical

    # A one-element choice list: goal_resistance can only take the CLI value.
    resistance = Categorical(a=[args.goal_resistance])
    environment.variant_space.variant_config["goal_resistance"] = resistance
    return environment, domain, name


def make_buff(args, env):
    """Build a multi-task trajectory buffer from the on-disk dataset.

    Every sub-directory of ``<args.dataset_path>/<domain>`` is one task. Its
    name must be a key of ``task_detail``. For each task, all
    ``trajectory/*.pkl`` files are loaded. Each trajectory gets that task's
    language embedding stored under the ``"task"`` key, and the whole list is
    added to the buffer via ``add_task``.

    Args:
        args: parsed CLI options; reads ``env_name`` (``"<domain>.<env>"``)
            and ``dataset_path``.
        env: environment whose observation/action spaces size the buffer.

    Returns:
        ``(buff, task_list)``: the populated ``MTTrajectoryBuffer`` and the
        task names in the order they were added (sorted).

    Raises:
        FileNotFoundError: no task directories exist under the dataset path.
        KeyError: a task directory has no entry in ``task_detail``.
    """
    domain_name, _ = args.env_name.split(".")
    # NOTE(review): positional args 100 / False kept as-is; their meaning is
    # defined by MTTrajectoryBuffer.
    buff = MTTrajectoryBuffer(100, False, env.observation_space, env.action_space)

    dataset_path = osp.join(args.dataset_path, domain_name)
    # Only directories are tasks; stray files would otherwise fail the
    # task_detail lookup. glob() order is filesystem-dependent, so sort to keep
    # the task order (and thus buffer insertion order) reproducible.
    task_paths = sorted(
        path for path in glob.glob(osp.join(dataset_path, "*")) if osp.isdir(path)
    )
    if not task_paths:
        raise FileNotFoundError(f"No task directories found under {dataset_path}")
    # basename() is separator-aware, unlike splitting on "/".
    task_list = [osp.basename(path) for path in task_paths]

    unknown = [task for task in task_list if task not in task_detail]
    if unknown:
        raise KeyError(f"No description in task_detail for task(s): {unknown}")
    task_embs = [get_skill_embed(lm, task_detail[task]) for task in task_list]

    print_r(f"<< Making Buffer for {domain_name}... >>")
    for task_path, task_emb in zip(task_paths, task_embs):
        print_b(f"Adding task at {task_path}")
        traj_paths = sorted(glob.glob(osp.join(task_path, "trajectory", "*.pkl")))
        trajs = [load_traj(path) for path in traj_paths]

        # Tag every trajectory with its task's language embedding.
        for traj in trajs:
            traj["task"] = task_emb

        buff.add_task(trajs)

    return buff, task_list


def evaluate(
    policy,
    env: gym.Env,
    domain_name: str,
    env_name: str,
    n_episodes: int = 10,
    video: bool = False,
    save_path: str = None,
) -> Tuple[np.array, np.array]:
    """Roll out ``policy`` in ``env`` and report success rate and episode length.

    The policy is conditioned on the embedding of ``task_detail[env_name]``.
    Each episode runs until ``env.step`` reports ``done``. If the policy has a
    ``reset`` method, it is called at the start of every episode. Summary
    statistics are printed. They are also appended to
    ``<save_path>/evaluation.txt`` when ``save_path`` is given.

    Args:
        policy: object with ``predict(obs, lang)`` (and optionally ``reset()``).
        env: environment to evaluate in.
        domain_name: domain label used only in the report.
        env_name: task name; must be a key of ``task_detail``.
        n_episodes: number of episodes to run (must be positive).
        video: if True, save one ``<env_name>_<episode>.mp4`` per episode to
            ``save_path``. The frame after the terminal step is not recorded.
        save_path: output directory for videos and ``evaluation.txt``.

    Returns:
        ``(successes, lengths)``: per-episode ``e_info["success"]`` values and
        step counts as NumPy arrays.

    Raises:
        ValueError: ``n_episodes`` is not positive, or ``video`` is set
            without a ``save_path``.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")
    if video and save_path is None:
        raise ValueError("save_path is required when video=True")

    tot_success, tot_length = [], []
    lang = get_skill_embed(lm, task_detail[env_name])

    # Context manager guarantees the progress bar is closed even on error.
    with tqdm(total=n_episodes) as pbar:
        for episode in range(n_episodes):
            obs, done, step = env.reset(), False, 0
            if getattr(policy, "reset", None) is not None:
                policy.reset()

            frames = []
            while not done:
                action = policy.predict(obs, lang)
                action = np.array(action.copy())
                obs, _, done, e_info = env.step(action)
                step += 1

                if done:
                    break

                if video:
                    frames.append(env.render())

            tot_success.append(e_info["success"])
            tot_length.append(step)

            pbar.update(1)
            pbar.set_description(f"Episodes: {episode + 1}/{n_episodes}")

            if video:
                make_dir(save_path)
                video_path = osp.join(save_path, f"{env_name}_{episode}.mp4")
                save_video(video_path, frames)

    # Success is reported as a percentage (assumes e_info["success"] is 0/1).
    avg_success = np.mean(tot_success, axis=0) * 100
    std_success = np.std(tot_success, axis=0) * 100
    avg_length, std_length = np.mean(tot_length, axis=0), np.std(tot_length, axis=0)

    # "+\\-" yields the same "+\-" text as before without relying on the
    # invalid "\-" escape sequence (a SyntaxWarning on Python 3.12+).
    length_line = f"\tTotal Length: {avg_length} +\\- {std_length}"
    success_line = f"\tTotal Success Rate : {avg_success} +\\- {std_success}"

    print_b("=" * 13 + " Performance Evaluation " + "=" * 13)
    print_r(f"\t{domain_name}.{env_name}")
    print(length_line)
    print(success_line)
    print_b("=" * 50)

    if save_path is not None:
        # Ensure the directory exists; it may not when video=False.
        make_dir(save_path)
        report = [
            "=" * 13 + " Performance Evaluation" + "=" * 13,
            f"\t{domain_name}.{env_name}",
            length_line,
            success_line,
            "=" * 50,
        ]
        with open(osp.join(save_path, "evaluation.txt"), "a", encoding="utf-8") as f:
            f.write("\n".join(report) + "\n")

    return np.asarray(tot_success), np.asarray(tot_length)


def train(args):
    """Train and/or evaluate a SkillDiffuser planner on a multi-task dataset.

    Results go under
    ``./results/skill_diffuser/<domain>/<env>/<args.tag>``.

    - If ``args.train`` is set, the planner is trained, then its config and
      weights are saved there. The run is logged to wandb.
    - If ``args.test`` is set, the saved planner is loaded and evaluated on
      every task found in the dataset.

    Args:
        args: parsed CLI options. Reads ``env_name``, ``tag``, ``train``,
            ``test``, ``n_episodes`` and ``video``, plus the fields read by
            ``make_env``/``make_buff``. ``args.env_name`` is temporarily
            changed during evaluation and restored afterwards.
    """
    # 0. Make Environment
    env, domain_name, env_name = make_env(args)

    # 1. Make Buffer
    buff, task_list = make_buff(args, env)
    num_task = len(task_list)
    print_r(f"Number of tasks {num_task}")

    # 2. Save Path
    save_path = osp.join("./results/skill_diffuser", domain_name, env_name, args.tag)
    model_path = osp.join(save_path, "planner")

    # 3. Load Config
    config = load_config("./config/experiments/skill_diffuser.yml")
    planner_cfg = config["planner"]

    # 4. Make Models (batch size scales with the number of tasks)
    model = SkillDiffuser(
        planner_cfg["params"]["policy"],
        env=env,
        replay_buffer=buff,
        batch_size=planner_cfg["params"]["batch_size"] * num_task,
        policy_kwargs=planner_cfg["policy_kwargs"],
    )

    # 5. Training
    if args.train:
        # wandb documents `tags` as a list of strings, not a bare string.
        wandb.init(project="diffgro", tags=["skill_diffuser"])

        make_dir(save_path)
        save_config(save_path, config)  # save configs
        model.learn(**planner_cfg["training"])
        model.save(path=model_path)
        wandb.finish()

    # 6. Evaluation on every task in the dataset
    if args.test:
        model.load(model_path)

        # make_env reads the env name from args, so point it at each task in
        # turn (keeping the run's domain) and restore the caller's value after.
        original_env_name = args.env_name
        try:
            for task in task_list:
                args.env_name = f"{domain_name}.{task}"
                task_env, task_domain, task_env_name = make_env(args)

                evaluate(
                    model,
                    task_env,
                    task_domain,
                    task_env_name,
                    args.n_episodes,
                    args.video,
                    save_path,
                )
        finally:
            args.env_name = original_env_name


if __name__ == "__main__":
    # Script entry point: parse the "train" CLI options and run train().
    cli_args = Parser("train").parse_args()
    train(cli_args)
